{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from collections import Counter\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Node:\n",
    "    def __init__(self, x=None, label=None, y=None, data=None):\n",
    "        self.label = label   \n",
    "        self.x = x           \n",
    "        self.child = []      \n",
    "        self.y = y          \n",
    "        self.data = data   \n",
    "\n",
    "    def append(self, node):  \n",
    "        self.child.append(node)\n",
    "\n",
    "    def predict(self, features): \n",
    "        if self.y is not None:\n",
    "            return self.y\n",
    "        for c in self.child:\n",
    "            if c.x == features[self.label]:\n",
    "                return c.predict(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def printnode(node, depth=0):  \n",
    "    if node.label is None:\n",
    "        print(depth, (node.label, node.x, node.y, len(node.data)))\n",
    "    else:\n",
    "        print(depth, (node.label, node.x))\n",
    "        for c in node.child:\n",
    "            printnode(c, depth+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DTree:\n",
    "    def __init__(self, epsilon=0, alpha=0): \n",
    "        self.epsilon = epsilon\n",
    "        self.alpha = alpha\n",
    "        self.tree = Node()\n",
    "\n",
    "    def prob(self, datasets): \n",
    "        datalen = len(datasets)\n",
    "        labelx = set(datasets)\n",
    "        p = {l: 0 for l in labelx}\n",
    "        for d in datasets:\n",
    "            p[d] += 1\n",
    "        for i in p.items():\n",
    "            p[i[0]] /= datalen\n",
    "        return p\n",
    "\n",
    "    def calc_ent(self, datasets): \n",
    "        p = self.prob(datasets)\n",
    "        ent = sum([-v * math.log(v, 2) for v in p.values()])\n",
    "        return ent\n",
    "\n",
    "    def cond_ent(self, datasets, col):  \n",
    "        labelx = set(datasets.iloc[col])\n",
    "        p = {x: [] for x in labelx}\n",
    "        for i, d in enumerate(datasets.iloc[-1]):\n",
    "            p[datasets.iloc[col][i]].append(d)\n",
    "        return sum([self.prob(datasets.iloc[col])[k] * self.calc_ent(p[k]) for k in p.keys()])\n",
    "\n",
    "    def info_gain_train(self, datasets, datalabels): \n",
    "        datasets = datasets.T\n",
    "        ent = self.calc_ent(datasets.iloc[-1])\n",
    "        gainmax = {}\n",
    "        for i in range(len(datasets) - 1):\n",
    "            cond = self.cond_ent(datasets, i)\n",
    "            gainmax[ent - cond] = i\n",
    "        m = max(gainmax.keys())\n",
    "        return gainmax[m], m\n",
    "\n",
    "    def train(self, datasets, node):\n",
    "        labely = datasets.columns[-1]\n",
    "        if len(datasets[labely].value_counts()) == 1:\n",
    "            node.data = datasets[labely]\n",
    "            node.y = datasets[labely][0]\n",
    "            return\n",
    "        if len(datasets.columns[:-1]) == 0:\n",
    "            node.data = datasets[labely]\n",
    "            node.y = datasets[labely].value_counts().index[0]\n",
    "            return\n",
    "        gainmaxi, gainmax = self.info_gain_train(datasets, datasets.columns)\n",
    "        if gainmax <= self.epsilon: \n",
    "            node.data = datasets[labely]\n",
    "            node.y = datasets[labely].value_counts().index[0]\n",
    "            return\n",
    "\n",
    "        vc = datasets[datasets.columns[gainmaxi]].value_counts()\n",
    "        for Di in vc.index:\n",
    "            node.label = gainmaxi\n",
    "            child = Node(Di)\n",
    "            node.append(child)\n",
    "            new_datasets = pd.DataFrame([list(i) for i in datasets.values if i[gainmaxi]==Di], columns=datasets.columns)\n",
    "            self.train(new_datasets, child)\n",
    "\n",
    "    def fit(self, datasets):\n",
    "        self.train(datasets, self.tree)\n",
    "\n",
    "    def findleaf(self, node, leaf): \n",
    "        for t in node.child:\n",
    "            if t.y is not None:\n",
    "                leaf.append(t.data)\n",
    "            else:\n",
    "                for c in node.child:\n",
    "                    self.findleaf(c, leaf)\n",
    "\n",
    "    def findfather(self, node, errormin):\n",
    "        if node.label is not None:\n",
    "            cy = [c.y for c in node.child]\n",
    "            if None not in cy:  \n",
    "                childdata = []\n",
    "                for c in node.child:\n",
    "                    for d in list(c.data):\n",
    "                        childdata.append(d)\n",
    "                childcounter = Counter(childdata)\n",
    "\n",
    "                old_child = node.child  \n",
    "                old_label = node.label\n",
    "                old_y = node.y\n",
    "                old_data = node.data\n",
    "\n",
    "                node.label = None  \n",
    "                node.y = childcounter.most_common(1)[0][0]\n",
    "                node.data = childdata\n",
    "\n",
    "                error = self.c_error()\n",
    "                if error <= errormin:  \n",
    "                    errormin = error\n",
    "                    return 1\n",
    "                else:\n",
    "                    node.child = old_child  \n",
    "                    node.label = old_label\n",
    "                    node.y = old_y\n",
    "                    node.data = old_data\n",
    "            else:\n",
    "                re = 0\n",
    "                i = 0\n",
    "                while i < len(node.child):\n",
    "                    if_re = self.findfather(node.child[i], errormin)  \n",
    "                    if if_re == 1:\n",
    "                        re = 1\n",
    "                    elif if_re == 2:\n",
    "                        i -= 1\n",
    "                    i += 1\n",
    "                if re:\n",
    "                    return 2\n",
    "        return 0\n",
    "\n",
    "    def c_error(self): \n",
    "        leaf = []\n",
    "        self.findleaf(self.tree, leaf)\n",
    "        leafnum = [len(l) for l in leaf]\n",
    "        ent = [self.calc_ent(l) for l in leaf]\n",
    "        print(\"Ent:\", ent)\n",
    "        error = self.alpha*len(leafnum)\n",
    "        for l, e in zip(leafnum, ent):\n",
    "            error += l*e\n",
    "        print(\"C(T):\", error)\n",
    "        return error\n",
    "\n",
    "    def cut(self, alpha=0): \n",
    "        if alpha:\n",
    "            self.alpha = alpha\n",
    "        errormin = self.c_error()\n",
    "        self.findfather(self.tree, errormin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = np.array([['青年', '否', '否', '一般', '否'],\n",
    "               ['青年', '否', '否', '好', '否'],\n",
    "               ['青年', '是', '否', '好', '是'],\n",
    "               ['青年', '是', '是', '一般', '是'],\n",
    "               ['青年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '好', '否'],\n",
    "               ['中年', '是', '是', '好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '好', '是'],\n",
    "               ['老年', '是', '否', '好', '是'],\n",
    "               ['老年', '是', '否', '非常好', '是'],\n",
    "               ['老年', '否', '否', '一般', '否']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datalabels = np.array(['年龄', '有工作', '有自己的房子', '信贷情况', '类别'])\n",
    "train_data = pd.DataFrame(datasets, columns=datalabels)\n",
    "test_data = ['老年', '否', '否', '一般']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dt = DTree(epsilon=0) \n",
    "dt.fit(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('DTree:')\n",
    "printnode(dt.tree)\n",
    "y = dt.tree.predict(test_data)\n",
    "print('result:', y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dt.cut(alpha=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('DTree:')\n",
    "printnode(dt.tree)\n",
    "y = dt.tree.predict(test_data)\n",
    "print('result:', y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "剪枝效果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets1 = np.array([['青年', '否', '否', '一般', '否'],\n",
    "               ['青年', '否', '否', '好', '否'],\n",
    "               ['青年', '是', '否', '好', '是'],\n",
    "               ['青年', '是', '是', '一般', '是'],\n",
    "               ['青年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '好', '否'],\n",
    "               ['中年', '是', '是', '好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '好', '是'],\n",
    "               ['老年', '是', '否', '好', '是'],\n",
    "               ['老年', '是', '否', '非常好', '是'],\n",
    "               ['老年', '否', '否', '一般', '否'],\n",
    "               ['青年', '否', '否', '一般', '是']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datalabels = np.array(['年龄', '有工作', '有自己的房子', '信贷情况', '类别'])\n",
    "train_data = pd.DataFrame(datasets1, columns=datalabels)\n",
    "test_data = ['老年', '否', '否', '一般']\n",
    "\n",
    "dt = DTree(epsilon=0) \n",
    "dt.fit(train_data)\n",
    "\n",
    "print('DTree:')\n",
    "printnode(dt.tree)\n",
    "y = dt.tree.predict(test_data)\n",
    "print('result:', y)\n",
    "print('----------------')\n",
    "dt.cut(alpha=0.5)\n",
    "\n",
    "print('----------------')\n",
    "print('DTree:')\n",
    "printnode(dt.tree)\n",
    "y = dt.tree.predict(test_data)\n",
    "print('result:', y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
